from .lstm import LSTM
from .cnn import CNN
from .transformer import Bert, Transformer
from .gru import GRU


class EncoderFactory:
    """Select and build an encoder from a model-name string.

    The name given at construction time is matched against a fixed, ordered
    set of rules in ``get_encoder``; the first rule that matches decides which
    encoder class is instantiated.
    """

    def __init__(self, model_name):
        # Kept as a public attribute; it is only read again in get_encoder.
        self.model_name = model_name

    def get_encoder(self, input_dim, config):
        """Instantiate the encoder that ``self.model_name`` selects.

        Rules are tried in this order (first match wins):

        1. name contains ``'lstm'``  -> ``LSTM``
        2. name contains ``'cnn'``   -> ``CNN``
        3. name is exactly ``'gru'`` -> ``GRU``
        4. name is exactly ``'bert'`` -> ``Bert``
        5. name contains ``'bert'``  -> ``Transformer``

        Because rules 1-2 are substring tests and run first, a combined name
        such as ``'bert_lstm'`` resolves to ``LSTM``.

        NOTE(review): ``'gru'`` requires an exact match while ``'lstm'``/``'cnn'``
        only need a substring, and ``Transformer`` is reachable only through
        names containing ``'bert'`` (e.g. ``'transformer'`` is rejected).
        Confirm this is the intended naming scheme.

        Args:
            input_dim: Forwarded unchanged as the first constructor argument.
            config: Forwarded unchanged as the second constructor argument.

        Returns:
            A new instance of the selected encoder class.

        Raises:
            ValueError: if no rule matches the model name.
        """
        name = self.model_name

        # Substring rules first: their precedence over the exact-match rules
        # below is part of the observable behavior.
        if 'lstm' in name:
            return LSTM(input_dim, config)
        if 'cnn' in name:
            return CNN(input_dim, config)

        if name == 'gru':
            return GRU(input_dim, config)

        # The exact 'bert' check must precede the broader 'bert' substring
        # check, otherwise Bert would never be selected.
        if name == 'bert':
            return Bert(input_dim, config)
        if 'bert' in name:
            return Transformer(input_dim, config)

        raise ValueError(f'{self.model_name} not implemented.')
